import os
from functools import lru_cache

import cv2
import numpy as np
from cnocr import CnOcr

import utils
from src import my_vgg


def choice_question_recognition(image, model):
    """
    Recognise the chosen option letters for every question in a column image.

    :param image: original image of a column of choice questions
    :param model: the trained classification model passed to get_selected_answer
    :return: the ``items()`` view of a dict mapping question number to the list
        of selected option letters for that question
    """
    per_question_strips = match_question_num_and_choices(image)
    answers_by_question = {
        question_num: get_selected_answer(strip, model)
        for question_num, strip in per_question_strips.items()
    }
    return answers_by_question.items()


def match_question_num_and_choices(image):
    """
    Crop the answer-option strip that belongs to each question number.

    The column image is binarised and morphologically closed so that nearby
    glyphs merge into blobs. Each blob is OCR'd:

    * Blobs whose text, with '.' removed, is a decimal integer are treated as
      question numbers.
    * All other blobs are treated as option blobs.

    For each question number, the option blobs whose vertical extent contains
    the number's centre row are merged into one bounding box, which is cropped
    from the original image.

    :param image: BGR image of one column of choice questions
    :return question_num_and_choices: dict mapping question number (int) to the
        cropped option image for that question. Question numbers with no
        matching option blob are omitted.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # With THRESH_OTSU the threshold is chosen automatically; the -1 is ignored.
    _, binary = cv2.threshold(gray, -1, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    binary = cv2.GaussianBlur(binary, (5, 5), 0)
    kernel = np.ones((5, 5), np.uint8)
    # Closing merges neighbouring strokes so each text item becomes one contour.
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=3)

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    question_nums_dict = {}
    choices_list = []
    ocr = CnOcr()
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Pad 2 px on each side for OCR. The right edge is computed before the
        # left edge is clamped at 0, so clamping does not widen the box further.
        left = max(0, x - 2)
        right = x + w + 2
        box = (left, y, right - left, h)
        temp_image = utils.get_image_by_box(image, box)
        text = utils.get_ocr_text_by_ocr(temp_image, ocr)
        question_num = ''.join(text.get('text')).replace('.', '')
        # isdecimal(), not isdigit(): isdigit() is also True for characters
        # such as '①' or '²', which int() cannot parse (ValueError).
        if question_num.isdecimal():
            question_nums_dict[int(question_num)] = box
        else:
            choices_list.append(contour)

    # Option bounding rects do not depend on the question, so compute them once.
    choices_with_rects = [(contour, cv2.boundingRect(contour)) for contour in choices_list]

    question_num_and_choices = {}
    for key, num_box in question_nums_dict.items():
        _, num_center_y = utils.get_center_of_box(num_box)
        # An option blob belongs to this question if it spans the number's centre row.
        choices_contours = [
            contour
            for contour, (_, choice_y, _, choice_h) in choices_with_rects
            if choice_y < num_center_y < choice_y + choice_h
        ]
        if not choices_contours:
            continue
        x, y, w, h = cv2.boundingRect(np.vstack(choices_contours))
        # Keep the crop at least utils.selected_box_h + 3 px tall.
        h = max(h, utils.selected_box_h + 3)
        question_num_and_choices[key] = utils.get_image_by_box(image, (x, y, w, h))

    return question_num_and_choices


def get_selected_boxes(image):
    """
    Locate the filled-in (selected) marks in one question's option strip.

    :param image: BGR crop of the answer area of a single question
    :return selected_boxes: list of (x, y, w, h) rectangles, one per detected
        mark. A box narrower or shorter than utils.selected_box_w /
        utils.selected_box_h is grown to that size around its centre, with
        its start clamped at 0.
    """
    kernel = np.ones((5, 5), np.uint8)

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)

    # Opening (erode -> dilate), presumably to drop thin strokes so that only
    # solid filled marks survive.
    for operation in (cv2.MORPH_ERODE, cv2.MORPH_DILATE):
        mask = cv2.morphologyEx(mask, operation, kernel, iterations=1)
    mask = cv2.GaussianBlur(mask, (5, 5), 0)
    # Closing (dilate -> erode) to fill small holes in the remaining blobs.
    for operation in (cv2.MORPH_DILATE, cv2.MORPH_ERODE):
        mask = cv2.morphologyEx(mask, operation, kernel, iterations=1)
    mask = cv2.GaussianBlur(mask, (5, 5), 0)
    _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # Outer contours of the marks.
    mark_contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

    def _grow(start, size, min_size):
        # Centre-expand a 1-D span up to min_size, clamping its start at 0.
        shortfall = min_size - size
        if shortfall <= 0:
            return start, size
        return max(0, start - int(shortfall / 2)), min_size

    boxes = []
    for contour in mark_contours:
        bx, by, bw, bh = cv2.boundingRect(contour)
        bx, bw = _grow(bx, bw, utils.selected_box_w)
        by, bh = _grow(by, bh, utils.selected_box_h)
        boxes.append((bx, by, bw, bh))
    return boxes


def get_choices_template(dir_path: str):
    """
    Load every template image found in a directory.

    Files are read in sorted filename order, so results do not depend on the
    arbitrary order of os.listdir. Entries that are not regular files, or
    that OpenCV cannot decode, are skipped. cv2.imread returns None for
    those, for example a stray .DS_Store or Thumbs.db, and a None image would
    make later cv2 calls on it fail.

    :param dir_path: path of the directory holding the template images
    :return: list of dicts ``{'image': BGR ndarray, 'name': filename up to its first '.'}``
    """
    choices = []
    for name in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, name)
        if not os.path.isfile(path):
            continue
        template_image = cv2.imread(path)
        if template_image is None:
            continue
        choices.append({'image': template_image, 'name': name.split('.')[0]})
    return choices


@lru_cache(maxsize=8)
def _load_blurred_unselected_templates(template_dir):
    """
    Load and preprocess the templates in template_dir, once per directory.

    Each template is converted to grayscale and Gaussian-blurred, using the
    same preprocessing as the searched image. The result is cached because
    get_unselected_box runs once per question. The template files are assumed
    not to change during a run; call ``cache_clear()`` if they do.

    :param template_dir: absolute path of the template directory (used as the cache key)
    :return: tuple of blurred grayscale template arrays (treat as read-only)
    """
    blurred = []
    for template in get_choices_template(template_dir):
        template_gray = cv2.cvtColor(template['image'], cv2.COLOR_BGR2GRAY)
        blurred.append(cv2.GaussianBlur(template_gray, (5, 5), 1))
    return tuple(blurred)


def get_unselected_box(image, template_dir="../choices/unselected"):
    """
    Use template matching to locate the unselected (unfilled) options.

    :param image: BGR crop of the answer area of a single question
    :param template_dir: directory of unselected-option templates. It defaults
        to the original hard-coded path, which is relative to the current
        working directory.
    :return unselected_boxes: one (x, y, w, h) box per template, placed at
        that template's best TM_CCOEFF_NORMED match. No score threshold is
        applied.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 1)
    # Key the cache on the absolute path so a later chdir cannot return stale templates.
    templates = _load_blurred_unselected_templates(os.path.abspath(template_dir))
    unselected_boxes = []
    for template in templates:
        scores = cv2.matchTemplate(blurred, template, cv2.TM_CCOEFF_NORMED)
        _, _, _, best_loc = cv2.minMaxLoc(scores)
        unselected_boxes.append((best_loc[0], best_loc[1], template.shape[1], template.shape[0]))
    return unselected_boxes


def get_selected_answer(image, model):
    """
    Return the letters of the options that are marked in one question.

    Candidate option boxes come from get_selected_boxes (filled marks) and
    get_unselected_box (template matches for unfilled options). They are
    sorted left to right by utils.sort_boxes, then passed through
    utils.get_less_cover_boxes, presumably to drop overlapping duplicates
    (confirm against utils). The i-th remaining box is option
    ``chr(ord('A') + i)``. Each box is classified with the model, and any
    label other than ``utils.label_dict['unselected']`` counts as selected.

    :param image: BGR crop of the answer area of a single question
    :param model: the trained model passed to my_vgg.judge_by_model
    :return answers: list of selected option letters, in left-to-right order
    """
    selected_boxes = get_selected_boxes(image)
    unselected_boxes = get_unselected_box(image)
    choices_boxes = utils.sort_boxes(selected_boxes + unselected_boxes)
    choices = utils.get_less_cover_boxes(choices_boxes)
    choices_images = [utils.get_image_by_box(image, b) for b in choices]

    unselected_label = utils.label_dict['unselected']
    answers = []
    for i, choices_image in enumerate(choices_images):
        index, _ = my_vgg.judge_by_model(choices_image, model)
        # Compare by value. The original `is not` tested identity, which is
        # True for a numpy/torch integer (or a large int) equal to the label,
        # so every option would have been reported as selected.
        if index != unselected_label:
            answers.append(chr(ord('A') + i))
    return answers


if __name__ == "__main__":
    # Demo: recognise one sample column of choice questions and print the
    # (question number, answers) pairs ordered by question number.
    sample_image = cv2.imread("../pic/choices_question/1.jpg")
    vgg_model = my_vgg.get_trained_my_vgg16()
    recognized = choice_question_recognition(sample_image, vgg_model)
    ordered = sorted(recognized, key=lambda pair: pair[0])
    print(ordered)
    print(type(ordered))
